/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.extension.plugins.inner;

import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.MappedStatement;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;

/**
 * 主键生成拦截器-用于批量插入
 * @author wanglei
 * @since 4.0.0
 */
public class KeyGeneratorInterceptor implements InnerInterceptor {
    private static final Log logger = LogFactory.getLog(KeyGeneratorInterceptor.class);



    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        if (ms.getId().contains(".insertBatch")) {
            if (parameter instanceof Map) {
                Map<String, Object> param = (Map<String, Object>) parameter;
                if (param.containsKey(Constants.COLLECTION)) {
                    List collection = (List) param.get(Constants.COLLECTION);
                    initPkey(collection);
                }
            }
        }
    }


    /**
     * 通过 序列初始化主键
     *
     * @param collection 需要初始化主键的对象
     */
    public void initPkey(List collection) {
        if (collection == null || collection.size() == 0) {
            return;
        }
        //拿到对象
        Object obj = collection.get(0);
        if (obj == null) {
            return;
        }
        TableInfo tableInfo = TableInfoHelper.getTableInfo(obj.getClass());
        // 没有key生成器，没有序列不处理
        if (tableInfo == null || tableInfo.getKeySequence() == null || tableInfo.getKeyGenerator() == null) {
            return;
        }
        //主键不是long也不是int    不处理
        if (tableInfo.getKeyType() == null || !Integer.class.equals(tableInfo.getKeyType()) && !Long.class.equals(tableInfo.getKeyType())) {
            return;
        }
        //拿到sql
        String selectSql = tableInfo.getKeyGenerator().executeSql(tableInfo.getKeySequence().value());
        Field keyField = tableInfo.getKeyField();
        keyField.setAccessible(true);
        Connection connection = SqlHelper.sqlSession(obj.getClass()).getConnection();
        for (Object po : collection) {
            try {
                keyField.set(po, getPkey(selectSql, tableInfo.getKeyType(), connection));
            } catch (IllegalAccessException e) {
                logger.error("Error:", e);
            }
        }
    }

    /**
     * 获取主键
     *
     * @param sql        查询sql
     * @param keyType    主键类型
     * @param connection 数据库连接
     * @param <T>        主键类型
     * @return 主键
     */
    public <T> T getPkey(String sql, Class<T> keyType, Connection connection) {
        PreparedStatement countStmt = null;
        ResultSet rs = null;
        Object pkey = null;
        try {
            countStmt = connection.prepareStatement(sql);
            rs = countStmt.executeQuery();
            if (rs.next()) {
                if (keyType.equals(Integer.class)) {
                    pkey = rs.getInt(1);
                } else {
                    pkey = rs.getLong(1);
                }
            }

        } catch (SQLException e) {
            logger.error("Error:", e);
        } finally {
            try {
                rs.close();
                countStmt.close();
            } catch (SQLException e) {
                logger.error("Error:", e);
            }
        }
        return (T) pkey;
    }

}
